package com.three.db.config;

import com.dangdang.ddframe.rdb.sharding.api.ShardingDataSourceFactory;
import com.dangdang.ddframe.rdb.sharding.api.rule.DataSourceRule;
import com.dangdang.ddframe.rdb.sharding.api.rule.ShardingRule;
import com.dangdang.ddframe.rdb.sharding.api.rule.TableRule;
import com.dangdang.ddframe.rdb.sharding.api.strategy.database.DatabaseShardingAlgorithm;
import com.dangdang.ddframe.rdb.sharding.api.strategy.database.DatabaseShardingStrategy;
import com.dangdang.ddframe.rdb.sharding.api.strategy.database.NoneDatabaseShardingAlgorithm;
import com.dangdang.ddframe.rdb.sharding.api.strategy.database.SingleKeyDatabaseShardingAlgorithm;
import com.dangdang.ddframe.rdb.sharding.api.strategy.table.NoneTableShardingAlgorithm;
import com.dangdang.ddframe.rdb.sharding.api.strategy.table.SingleKeyTableShardingAlgorithm;
import com.dangdang.ddframe.rdb.sharding.api.strategy.table.TableShardingAlgorithm;
import com.dangdang.ddframe.rdb.sharding.api.strategy.table.TableShardingStrategy;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.three.constant.Separator;
import com.three.db.annotation.Sharding;
import com.three.db.annotation.Shardkey;
import com.three.db.schema.TableSchema;
import com.three.db.shard.base.ShardingType;
import com.three.db.shard.modulo.ModuloDbShardingAlgorithm;
import com.three.db.shard.modulo.ModuloTbShardingAlgorithm;
import com.three.db.shard.time.TimeDbShardingAlgorithm;
import com.three.db.shard.time.TimeTbShardingAlgorithm;
import com.three.utils.LogUtils;
import com.three.utils.PackageScanner;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.StringUtils;
import org.springframework.boot.autoconfigure.jdbc.DataSourceBuilder;
import org.springframework.context.EnvironmentAware;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.DependsOn;
import org.springframework.core.env.Environment;

import javax.annotation.PostConstruct;
import javax.persistence.Table;
import javax.sql.DataSource;
import java.lang.reflect.Constructor;
import java.lang.reflect.InvocationTargetException;
import java.time.LocalDate;
import java.time.temporal.ChronoUnit;
import java.util.*;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * <p>数据源配置，
 * 支持异构数据源</p>
 * Created by ziqingwang on 2017/5/18 0018.
 */
@Configuration
@DependsOn("packageScanner")
public class DataSourceConfig implements EnvironmentAware {
    private static Environment environment;
    private static String URL = "url";
    private static String USER_NAME = "username";
    private static String PASS_WORD = "password";
    private static String DRIVER_CLASS_NAME = "driver-class-name";
    private static String DATASOURCE_PREFIX = "chess.db";
    private static String DATASOURCE_ORDINARY = "chess.db.ordinary";
    private static String DATASOURCE_SHARDING = "chess.db.sharding";
    private static String DEFAULT = "default";
    private static final Map<String, DataSource> DATA_SOURCE_MAP = Maps.newHashMap();
    public static final Map<String,TableSchema> SHARDING_TABLE_NAME_MAP=Maps.newHashMap();
    @Override
    public void setEnvironment(Environment environment) {
        this.environment = environment;
    }

    @PostConstruct
    public void loadDataSource() {
        LogUtils.Console.warn("——————————-db 配置 start——————————");
        loadOrdinaryDataSource();
        loadShardingDataSource();
        LogUtils.Console.warn("——————————-db 配置 end——————————");
    }

    public DataSource getDataSource(String key) {
        return DATA_SOURCE_MAP.get(DATASOURCE_PREFIX + Separator.POINT + key);
    }

    //加载普通数据源
    private void loadOrdinaryDataSource() {
        String ordinaryDbs=environment.getProperty(DATASOURCE_ORDINARY);
        if(StringUtils.isEmpty(ordinaryDbs)){
            LogUtils.Console.warn("db——无普通数据源加载!");
            return;
        }
        Arrays.stream(ordinaryDbs.split(Separator.COMMA))
                .forEach(k -> {
                    LogUtils.Console.warn("db——加载普通数据源：{}",k);
                    creatOrdinaryDataSource(DATASOURCE_PREFIX, k);
                });
    }

    //加载分片数据源
    private void loadShardingDataSource() {
        String shardingDbs=environment.getProperty(DATASOURCE_SHARDING);
        if(StringUtils.isEmpty(shardingDbs)){
            LogUtils.Console.warn("db——无分片数据源加载!");
            return;
        }
        Arrays.stream(shardingDbs.split(Separator.COMMA))
                .forEach(k -> {
                    LogUtils.Console.warn("db——加载分片数据源：{}",k);
                    creatShardingDataSource(k);
                });
    }

    //创建分片数据源
    public void creatShardingDataSource(String dbKey) {
        DataSource dataSource = null;
        //查找需要分库分表的实体
        Set<Class> classSet = PackageScanner.clazzCollection.stream()
                .filter(clazz -> {
                    Sharding sharding = clazz.getAnnotation(Sharding.class);
                    Table table=clazz.getAnnotation(Table.class);
                    return Objects.nonNull(sharding)&&Objects.nonNull(table) ?
                            sharding.db_name().equals(dbKey) : false;
                })
                .collect(Collectors.toSet());
        if (CollectionUtils.isEmpty(classSet)) {
            LogUtils.Console.warn("没有分片规则，数据库:{}", dbKey);
            return;
        }
        List<TableRule> tb_rule_list = new ArrayList<>(classSet.size());
        Map<String, DataSource> db_maps = Maps.newHashMap();
        classSet.stream().forEach(
                (Class clazz) -> {
                    Sharding sharding = (Sharding) clazz.getAnnotation(Sharding.class);
                    Table table = (Table) clazz.getAnnotation(Table.class);
                    if (Objects.isNull(sharding) || Objects.isNull(table)) return;
                    String tb_name = table.name();
                    int db_num = sharding.shard_db_num();
                    int tb_num = sharding.shard_tb_num();
                    ShardingType shard_db_type = sharding.shard_db_type();
                    ShardingType shard_tb_type = sharding.shard_tb_type();
                    String shard_db_name = sharding.shard_db_name();
                    String shard_tb_name = sharding.shard_tb_name();
                    Class<? extends DatabaseShardingAlgorithm> shard_db_clazz = sharding.shard_db_clazz();
                    Class<? extends TableShardingAlgorithm> shard_tb_clazz = sharding.shard_tb_clazz();
                    ChronoUnit shard_db_unit = sharding.shard_db_unit();
                    ChronoUnit shard_tb_unit = sharding.shard_tb_unit();
                    Map<Shardkey.TYPE, String> shard_key_map = Arrays.stream(clazz.getDeclaredFields())
                            .filter(field -> {
                                return Objects.nonNull(field.getAnnotation(Shardkey.class));
                            })
                            .collect(Collectors.toMap(field -> {
                                        Shardkey shardkey = field.getAnnotation(Shardkey.class);
                                        return shardkey.type();
                                    }, field -> field.getName()
                            ));
                    String shard_db_key = StringUtils.isEmpty(shard_key_map.get(Shardkey.TYPE.ALL)) ?
                            shard_key_map.get(Shardkey.TYPE.DB) : shard_key_map.get(Shardkey.TYPE.ALL);
                    String shard_tb_key = StringUtils.isEmpty(shard_key_map.get(Shardkey.TYPE.ALL)) ?
                            shard_key_map.get(Shardkey.TYPE.TB) : shard_key_map.get(Shardkey.TYPE.ALL);
                    Set<String> shardDbNameSet = getShardingDbOrTbName("db", shard_db_type, shard_db_name, db_num, shard_db_unit);
                    Set<String> shardTbNameSet = getShardingDbOrTbName(tb_name, shard_tb_type, shard_tb_name, tb_num, shard_tb_unit);
                    shardTbNameSet=converDbOrTbName(1,tb_name,shardTbNameSet);
                    //保存命名，供数据库结构更新使用
                    TableSchema shardingTable=new TableSchema();
                    shardingTable.setShardingTable(shardTbNameSet);
                    shardingTable.setShardingDataBase(converDbOrTbName(0,dbKey,shardDbNameSet));
                    SHARDING_TABLE_NAME_MAP.put(tb_name,shardingTable);
                    //创建数据库
                    Map<String, DataSource> dataSourceMap = shardDbNameSet.stream()
                            .collect(
                                    Collectors.toMap(
                                            db_key -> {
                                                return dbKey + Separator.UNDERLINE + db_key;
                                            }
                                            , db_key -> {
                                                return creatOrdinaryDataSource(DATASOURCE_PREFIX + Separator.POINT + dbKey, db_key);
                                            }));
                    db_maps.putAll(dataSourceMap);
                    DataSourceRule db_rule = new DataSourceRule(dataSourceMap, dataSourceMap.keySet().iterator().next());
                    try {
                        SingleKeyDatabaseShardingAlgorithm databaseShardingAlgorithm = null;
                        SingleKeyTableShardingAlgorithm tableShardingAlgorithm = null;
                        switch (shard_db_type) {
                            case TIME:
                                shard_db_clazz = TimeDbShardingAlgorithm.class;
                                try {
                                    Constructor constructor = shard_db_clazz.getConstructor(ChronoUnit.class);
                                    databaseShardingAlgorithm = (SingleKeyDatabaseShardingAlgorithm) constructor.newInstance(shard_db_unit);
                                } catch (NoSuchMethodException | InvocationTargetException e) {
                                    LogUtils.Console.error("数据库分片：分库规则：{}，timeUnit:{}", shard_db_clazz, shard_db_unit);
                                }
                                break;
                            case MOLD:
                                shard_db_clazz = ModuloDbShardingAlgorithm.class;
                                databaseShardingAlgorithm = ((Class<SingleKeyDatabaseShardingAlgorithm>) shard_db_clazz).newInstance();
                                break;
                            case SELF:
                                databaseShardingAlgorithm = ((Class<SingleKeyDatabaseShardingAlgorithm>) shard_db_clazz).newInstance();
                                break;
                            case NONE:
                                shard_db_key="";
                                databaseShardingAlgorithm=new NoneDatabaseShardingAlgorithm();
                                break;
                            default:
                                databaseShardingAlgorithm = ((Class<SingleKeyDatabaseShardingAlgorithm>) shard_db_clazz).newInstance();
                                break;
                        }
                        switch (shard_tb_type) {
                            case TIME:
                                shard_tb_clazz = TimeTbShardingAlgorithm.class;
                                try {
                                    Constructor constructor = shard_tb_clazz.getConstructor(ChronoUnit.class);
                                    tableShardingAlgorithm = (SingleKeyTableShardingAlgorithm) constructor.newInstance(shard_tb_unit);
                                } catch (NoSuchMethodException | InvocationTargetException e) {
                                    LogUtils.Console.error("数据库分片：分表规则：{}，timeUnit:{}", shard_tb_clazz, shard_tb_unit);
                                }
                                break;
                            case MOLD:
                                shard_tb_clazz = ModuloTbShardingAlgorithm.class;
                                tableShardingAlgorithm = ((Class<SingleKeyTableShardingAlgorithm>) shard_tb_clazz).newInstance();
                                break;
                            case SELF:
                                tableShardingAlgorithm = ((Class<SingleKeyTableShardingAlgorithm>) shard_tb_clazz).newInstance();
                                break;
                            case NONE:
                                shard_tb_key="";
                                tableShardingAlgorithm= new NoneTableShardingAlgorithm();
                                break;
                            default:
                                tableShardingAlgorithm = ((Class<SingleKeyTableShardingAlgorithm>) shard_tb_clazz).newInstance();
                                break;
                        }
                        TableRule.TableRuleBuilder tableRuleBuilder=TableRule.builder(table.name())
                                .dataSourceRule(db_rule)
                                .actualTables(Lists.newArrayList(shardTbNameSet))
                                .databaseShardingStrategy(new DatabaseShardingStrategy(shard_db_key, databaseShardingAlgorithm))
                                .tableShardingStrategy(new TableShardingStrategy(shard_tb_key, tableShardingAlgorithm));
                        tb_rule_list.add(tableRuleBuilder.build());
                    } catch (IllegalAccessException | InstantiationException e) {
                        LogUtils.Console.error("分库分表异常：{}",e);
                    }
                }
        );
        ShardingRule shardingRule = ShardingRule.builder()
                .dataSourceRule(new DataSourceRule(db_maps, Objects.isNull(db_maps.get(DEFAULT))?db_maps.keySet().iterator().next():DEFAULT))
                .tableRules(tb_rule_list)
                .build();
        dataSource = ShardingDataSourceFactory.createDataSource(shardingRule);
        DATA_SOURCE_MAP.put(DATASOURCE_PREFIX + Separator.POINT + dbKey, dataSource);
    }

    //获取命名
    private Set<String> getShardingDbOrTbName(String prefix, ShardingType shard_type, String shard_names, int num, ChronoUnit shard_unit) {
        Set<String> result = Sets.newHashSet();
        switch (shard_type) {
            //求模
            case MOLD:
                result = IntStream.range(0, num)
                        .mapToObj(i -> {
                            return prefix + Separator.UNDERLINE + i;
                        })
                        .collect(Collectors.toSet());
                break;
            //时间（先提供按月吧）
            case TIME:
                String year = String.valueOf(LocalDate.now().getYear());
                if (shard_unit.equals(ChronoUnit.MONTHS)) {
                    result = IntStream.range(1, 13)
                            .mapToObj(i -> {
                                return prefix + Separator.UNDERLINE + year + (i >= 10 ? i : "0" + i);
                            })
                            .collect(Collectors.toSet());
                }
                break;
            //自定义
            case SELF:
                result = Sets.newHashSet(StringUtils.split(shard_names, Separator.COMMA));
                break;
            case NONE:
                result.add(DEFAULT);
                break;
            default:
                result = Sets.newHashSet(StringUtils.split(shard_names, Separator.COMMA));
                break;
        }
        return result;
    }

    //创建普通数据源
    private DataSource creatOrdinaryDataSource(String prefix, String key) {
        String key_ = key.equals(DEFAULT)?
                prefix:prefix + Separator.POINT + key;
        DataSource dataSource = DATA_SOURCE_MAP.get(key_);
        if (Objects.isNull(dataSource)) {
            dataSource = DataSourceBuilder.create()
                    .url(environment.getProperty(key_  + Separator.POINT + URL))
                    .username(environment.getProperty(key_   + Separator.POINT + USER_NAME))
                    .password(environment.getProperty(key_   + Separator.POINT + PASS_WORD))
                    .driverClassName(environment.getProperty(key_ + Separator.POINT + DRIVER_CLASS_NAME))
                    .build();
            DATA_SOURCE_MAP.put(key_, dataSource);
        }
        return dataSource;
    }

    /**
     * 转换下命名，转成数据库全命名
     * @param type 0：DB 1：TB
     * @param logicName 逻辑命名
     * @param nameSet 分片命名
     * @return
     */
    private Set<String> converDbOrTbName(int type,String logicName,Set<String> nameSet){
        Set<String> result=Sets.newHashSet();
        if(type==0){//db没拼逻辑名，传入的是db_0,db_1...
            result = nameSet.stream()
                    .map(db -> {
                        if (db.equals(DEFAULT)) return logicName;
                        return logicName + Separator.UNDERLINE + db;
                    })
                    .collect(Collectors.toSet());
        }else if(type==1){//tb传入的加了逻辑名，传入t_demo_201701,但要处理默认表
            result = nameSet.stream()
                    .map(tb -> {
                        if (tb.equals(DEFAULT)) return logicName;
                        return tb;
                    })
                    .collect(Collectors.toSet());
        }
        return result;
    }

}
